{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/S22/anaconda3/envs/BasicTS/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "# Resolve the project root (two levels above the current working directory)\n",
    "# only once: re-running this cell after the chdir below would otherwise climb\n",
    "# two more levels on every execution.\n",
    "if 'PROJECT_DIR' not in globals():\n",
    "    PROJECT_DIR = os.path.abspath(os.path.join(os.getcwd(), os.pardir, os.pardir))\n",
    "os.chdir(PROJECT_DIR)\n",
    "\n",
    "# Kept after the chdir, matching the original statement order of this cell.\n",
    "from basicts import launch_runner, BaseRunner\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "MODEL_NAME = 'MLP'\n",
    "DATASET_NAME = 'Pulse'\n",
    "BATCH_SIZE = 32\n",
    "GPUS = '2'\n",
    "\n",
    "cfg_path = 'baselines/{0}/{1}.py'.format(MODEL_NAME, DATASET_NAME) # NOTE: use relative path\n",
    "ckpt_path = '' # NOTE: use relative path\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def inference(cfg: dict, runner: BaseRunner, ckpt: str = None, batch_size: int = 1):\n",
    "    \"\"\"Run the runner's test routine and expose its return value as the global ``results``.\n",
    "\n",
    "    Args:\n",
    "        cfg: configuration object; ``cfg.TEST.DATA.BATCH_SIZE`` is overwritten with ``batch_size``.\n",
    "        runner: runner instance whose logger, model and test pipeline are driven here.\n",
    "        ckpt: checkpoint path handed to ``runner.load_model``.\n",
    "        batch_size: batch size written into ``cfg`` before the graph is set up.\n",
    "\n",
    "    Returns:\n",
    "        None. The value returned by ``runner.test()`` is stored in the module-level ``results``.\n",
    "    \"\"\"\n",
    "    # init logger\n",
    "    runner.init_logger(logger_name='easytorch-inference', log_file_name='validate_result')\n",
    "    # init model\n",
    "    # NOTE: cfg is mutated in place. Attribute-style access is used here, so cfg is\n",
    "    # presumably a config object rather than a plain dict, despite the annotation.\n",
    "    cfg.TEST.DATA.BATCH_SIZE = batch_size\n",
    "    runner.model.eval()\n",
    "    runner.setup_graph(cfg=cfg, train=False)\n",
    "    # load model checkpoint\n",
    "    runner.load_model(ckpt_path=ckpt)\n",
    "    # test\n",
    "    runner.init_test(cfg)\n",
    "    # Publish the outcome at notebook scope so that later cells can read `results`.\n",
    "    # NOTE(review): presumably this relies on launch_runner calling this function\n",
    "    # inside the kernel process - confirm for multi-device runs.\n",
    "    global results\n",
    "    results = runner.test()\n",
    "\n",
    "# Hand `inference` to the BasicTS launcher together with the config path and devices.\n",
    "# NOTE(review): (ckpt_path, BATCH_SIZE) is presumably forwarded as the `ckpt` and\n",
    "# `batch_size` arguments of `inference` - confirm against launch_runner.\n",
    "launch_runner(cfg_path, inference, (ckpt_path, BATCH_SIZE), devices=GPUS)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Detach each entry of `results`, move it to the CPU and convert it to a NumPy array.\n",
    "prediction, target, inputs = (\n",
    "    results[key].detach().cpu().numpy()\n",
    "    for key in ('prediction', 'target', 'inputs')\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "# prediction.shape: (num_samples, len_prediction, num_time_series, num_channels)\n",
    "# Pick a random sample / time series / channel to visualise.\n",
    "# random.randrange(n) draws from [0, n), so every index is valid. The previous\n",
    "# random.randint(0, prediction.shape[0]) was inclusive on both ends and could\n",
    "# return prediction.shape[0] itself, causing an IndexError when indexing later.\n",
    "sample_id = random.randrange(prediction.shape[0])\n",
    "time_series_id = random.randrange(prediction.shape[2])\n",
    "channel_id = random.randrange(prediction.shape[3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot history, targets, and predictions\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "# x positions: the history occupies [0, history_len); target and prediction follow it.\n",
    "history_len = inputs.shape[1]\n",
    "history_steps = np.arange(history_len)\n",
    "target_steps = np.arange(history_len, history_len + target.shape[1])\n",
    "prediction_steps = np.arange(history_len, history_len + prediction.shape[1])\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.plot(history_steps, inputs[sample_id, :, time_series_id, channel_id], label='history')\n",
    "ax.plot(target_steps, target[sample_id, :, time_series_id, channel_id], label='target')\n",
    "ax.plot(prediction_steps, prediction[sample_id, :, time_series_id, channel_id], label='prediction')\n",
    "# The indices are drawn at random, so record them on the figure itself.\n",
    "ax.set_title(f'sample {sample_id}, time series {time_series_id}, channel {channel_id}')\n",
    "ax.set_xlabel('time step')\n",
    "ax.set_ylabel('value')\n",
    "ax.legend()\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "BasicTS",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
